import os

from langchain_community.utils.math import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

from LongChainDemo01.local_knolege_robot import embeddings

# Persona prompt for physics questions. The "{query}" placeholder is filled
# in later with the user's question.
physics_template = (
    "You are a very smart physics professor. "
    "You are great at answering questions about physics in a concise and easy to understand manner. "
    "When you don't know the answer to a question you admit that you don't know.\n"
    "\n"
    "Here is a question:\n"
    "{query}"
)

# Persona prompt for math questions; same "{query}" placeholder convention.
math_template = (
    "You are a very good mathematician. You are great at answering math questions. "
    "You are so good because you are able to break down hard problems into their component parts, "
    "answer the component parts, and then put them together to answer the broader question.\n"
    "\n"
    "Here is a question:\n"
    "{query}"
)

# Candidate prompts the router chooses between; order defines the index
# used when picking the best match.
prompt_templates = [physics_template, math_template]

# NOTE: this rebinds the `embeddings` name that is also imported at the top
# of the file, so from here on the HuggingFace model below is the one used.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Embed the templates once up front so routing only has to embed the query.
prompt_embeddings = embeddings.embed_documents(prompt_templates)
def prompt_router(input):
    """Choose the prompt template that best matches the incoming question.

    The question is embedded and compared (cosine similarity) against the
    pre-computed embeddings of every entry in ``prompt_templates``; the
    highest-scoring template wins.

    Args:
        input: Mapping with a ``"query"`` key holding the question text.

    Returns:
        A ``PromptTemplate`` built from the selected template string.

    Side effects:
        Prints ``"Using MATH"`` when the math template is selected and
        ``"Using PHYSICS"`` otherwise.
    """
    question_vector = embeddings.embed_query(input["query"])
    # cosine_similarity works on 2-D inputs; wrap the single query vector in
    # a list and take row 0 to get one score per template.
    scores = cosine_similarity([question_vector], prompt_embeddings)[0]
    best_index = scores.argmax()
    chosen_template = prompt_templates[best_index]

    if chosen_template == math_template:
        print("Using MATH")
    else:
        print("Using PHYSICS")

    return PromptTemplate.from_template(chosen_template)

# Chat model client pointed at the Moonshot endpoint through ChatOpenAI.
# The API key is taken from the MOONSHOT_API_KEY environment variable so a
# real credential never has to be committed to the source; when the variable
# is unset the original placeholder value "key" is used, preserving the
# previous behavior.
chat_model = ChatOpenAI(
    openai_api_key=os.environ.get("MOONSHOT_API_KEY", "key"),
    openai_api_base="https://api.moonshot.cn/v1",
    model="moonshot-v1-8k",
    temperature=0,
    request_timeout=60,  # passed through to ChatOpenAI as the request timeout
    max_retries=3,
)

# Pipeline stages, in execution order:
#   1. wrap the raw input as {"query": <input>}
#   2. pick a prompt template via prompt_router
#   3. send the resulting prompt to the chat model
#   4. reduce the model reply to a plain string
wrap_query = {"query": RunnablePassthrough()}
route_prompt = RunnableLambda(prompt_router)
chain = wrap_query | route_prompt | chat_model | StrOutputParser()

# Run the demo question through the chain and show the answer.
answer = chain.invoke("What's a black hole")
print(answer)